# ~*~ coding: utf-8 ~*~
from ansible.inventory.host import Host
from ansible.vars.manager import VariableManager
from ansible.inventory.manager import InventoryManager
from ansible.parsing.dataloader import DataLoader


# Public names exported by ``from <module> import *``.
__all__ = ['BaseHost', 'BaseInventory']


class BaseHost(Host):
    """
    An Ansible ``Host`` built from a plain dict of connection settings.

    Keys read from ``host_data``: ``hostname``, ``ip``, ``port``,
    ``username``, ``password``, ``private_key``, ``become`` (a dict with
    optional ``method``/``user``/``pass`` keys, or any other truthy value to
    use the defaults) and ``vars`` (extra host variables copied verbatim).
    """

    def __init__(self, host_data):
        """
        :param host_data: dict describing the host; see the class docstring.
        :raises KeyError: if neither ``hostname`` nor ``ip`` is set.
        """
        self.host_data = host_data
        # The inventory name prefers the hostname and falls back to the IP.
        hostname = host_data.get('hostname') or host_data.get('ip')
        if not hostname:
            raise KeyError("host_data must provide 'hostname' or 'ip'")
        port = host_data.get('port') or 22
        super().__init__(hostname, port)
        self.__set_required_variables()
        self.__set_extra_variables()

    def __set_required_variables(self):
        """Set connection, credential and privilege-escalation variables."""
        host_data = self.host_data
        # Connect to the IP when one is given; otherwise use the hostname
        # rather than failing with KeyError('ip').
        address = host_data.get('ip') or host_data.get('hostname')
        self.set_variable('ansible_host', address)
        self.set_variable('ansible_port', host_data.get('port') or 22)

        if host_data.get('username'):
            self.set_variable('ansible_user', host_data['username'])

        # Credentials: password and/or private key.
        if host_data.get('password'):
            self.set_variable('ansible_ssh_pass', host_data['password'])
        if host_data.get('private_key'):
            self.set_variable('ansible_ssh_private_key_file', host_data['private_key'])

        # Privilege escalation ("become") support.
        become = host_data.get("become", False)
        if become:
            # A dict supplies method/user/pass; any other truthy value
            # (e.g. ``True``) means "become with the defaults" instead of
            # crashing on ``.get`` of a non-dict.
            options = become if hasattr(become, 'get') else {}
            self.set_variable("ansible_become", True)
            self.set_variable("ansible_become_method", options.get('method', 'sudo'))
            self.set_variable("ansible_become_user", options.get('user', 'root'))
            self.set_variable("ansible_become_pass", options.get('pass', ''))
        else:
            self.set_variable("ansible_become", False)

    def __set_extra_variables(self):
        """Copy every entry of ``host_data['vars']`` onto the host as-is."""
        # ``vars`` may be missing or explicitly None; both mean "no extras".
        for key, value in (self.host_data.get('vars') or {}).items():
            self.set_variable(key, value)

    def __repr__(self):
        return self.name


class BaseInventory(InventoryManager):
    """
    Build an Ansible inventory object from an in-memory dict instead of
    inventory source files.
    """
    loader_class = DataLoader
    variable_manager_class = VariableManager
    host_manager_class = BaseHost

    def __init__(self, host_list=None):
        """
        :param host_list: mapping of group name to group data, e.g.::

                {
                    "group_name": {
                        "hosts": [
                            {
                                "hostname": "",
                                "ip": "",
                                "port": "",
                                "username": "",
                                "password": "",
                                "private_key": "",
                                "become": {
                                    "method": "",
                                    "user": "",
                                    "pass": "",
                                },
                            },
                        ],
                    },
                }

            ``None`` is treated as an empty mapping.
        """
        if host_list is None:
            host_list = {}
        # Set before super().__init__(), which presumably triggers
        # parse_sources() (overridden below) and therefore needs host_list.
        self.host_list = host_list
        self.loader = self.loader_class()
        # NOTE(review): created without loader/inventory arguments — confirm
        # callers wire these up if they rely on this VariableManager.
        self.variable_manager = self.variable_manager_class()
        super().__init__(self.loader)

    def get_groups(self):
        """Return the inventory's group mapping."""
        return self._inventory.groups

    def get_group(self, name):
        """Return the group called ``name``, or None if it does not exist."""
        return self._inventory.groups.get(name)

    def parse_sources(self, cache=False):
        """
        Populate the inventory from ``self.host_list``.

        Each top-level key becomes a group (created if missing). Every host
        listed under it is added to that group and to the ``all`` group.

        :param cache: kept for interface compatibility; unused here.
        """
        group_all = self.get_group('all')
        for group_name, group_data in self.host_list.items():
            group = self.get_group(group_name)
            if group is None:
                self.add_group(group_name)
                group = self.get_group(group_name)
            # A group may be declared without hosts; treat that as empty
            # instead of failing while iterating over None.
            for host_data in (group_data or {}).get("hosts") or []:
                host = self.host_manager_class(host_data=host_data)
                # Key by the host's own name (hostname, falling back to ip)
                # rather than host_data['hostname'], which raised KeyError
                # when absent and made hostname-less hosts collide on "".
                # NOTE(review): a host listed in several groups gets a
                # separate object per group and the last one wins here —
                # confirm that is intended.
                self.hosts[host.name] = host
                group.add_host(host)
                group_all.add_host(host)

    def get_matched_hosts(self, pattern):
        """Return the hosts matching an Ansible host ``pattern``."""
        return self.get_hosts(pattern)
